{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8162093e-99cb-4646-93be-9e3a63eecc84",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import shutil\n",
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "import cv2\n",
    "from functools import partial\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from io import BytesIO\n",
    "import IPython\n",
    "from sklearn.cluster import DBSCAN\n",
    "\n",
    "from plsc.engine.inference import Predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd617b7-4ace-4d03-ac67-ed6c65b96e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the pretrained models and the demo images (skipped when already present).\n",
    "# `tar -xf` lets tar auto-detect compression, so the same command handles the plain\n",
    "# `.tar` detector archive and the gzipped `.tgz` archives. `wget -O` overwrites a\n",
    "# leftover partial download instead of saving a numbered `.1` copy next to it.\n",
    "!mkdir -p models\n",
    "if not os.path.exists('models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel'):\n",
    "    !wget -O models/blazeface_fpn_ssh_1000e_v1.0_infer.tar https://paddle-model-ecology.bj.bcebos.com/model/insight-face/blazeface_fpn_ssh_1000e_v1.0_infer.tar\n",
    "    !tar -xf models/blazeface_fpn_ssh_1000e_v1.0_infer.tar -C models/\n",
    "    !rm -f models/blazeface_fpn_ssh_1000e_v1.0_infer.tar\n",
    "\n",
    "if not os.path.exists('models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel'):\n",
    "    !wget -O models/FaceViT_tiny_patch9_112_infer.tgz https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_infer.tgz\n",
    "    !tar -xf models/FaceViT_tiny_patch9_112_infer.tgz -C models/\n",
    "    !rm -f models/FaceViT_tiny_patch9_112_infer.tgz\n",
    "\n",
    "# Test for the images themselves, not just the folder, so an interrupted\n",
    "# download (an `images/` directory without PNGs) is retried on the next run.\n",
    "if not glob.glob('images/*.png'):\n",
    "    !mkdir -p images\n",
    "    !wget -O images/BigBang.tgz https://plsc.bj.bcebos.com/dataset/BigBang.tgz\n",
    "    !tar -xf images/BigBang.tgz --strip-components 1 -C images\n",
    "    !rm -f images/BigBang.tgz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee26a767-0870-4059-86c4-d855f38434fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw(img, box_list, color='red', width=2):\n",
    "    \"\"\"Return a PIL image built from `img` with every box of `box_list` outlined.\n",
    "\n",
    "    Args:\n",
    "        img: image array accepted by `PIL.Image.fromarray` (callers in this\n",
    "            notebook pass an RGB view of a cv2 image).\n",
    "        box_list: iterable of detector rows; columns 2: are read as\n",
    "            (xmin, ymin, xmax, ymax).\n",
    "        color: outline color of every box (default 'red').\n",
    "        width: outline width in pixels (default 2).\n",
    "\n",
    "    Returns:\n",
    "        A `PIL.Image.Image` with the boxes drawn on it.\n",
    "    \"\"\"\n",
    "    im = Image.fromarray(img)\n",
    "    # Named `canvas` so it does not shadow the enclosing `draw` function.\n",
    "    canvas = ImageDraw.Draw(im)\n",
    "    for dt in box_list:\n",
    "        xmin, ymin, xmax, ymax = dt[2:]\n",
    "        canvas.rectangle(\n",
    "            [(xmin, ymin), (xmax, ymax)], width=width, outline=color)\n",
    "    return im\n",
    "\n",
    "def display_img_array(img):\n",
    "    \"\"\"Show a PIL image inline in the notebook by round-tripping it through PNG bytes.\"\"\"\n",
    "    png_buffer = BytesIO()\n",
    "    img.save(png_buffer, format='png')\n",
    "    png_bytes = png_buffer.getvalue()\n",
    "    IPython.display.display(IPython.display.Image(png_bytes, format='png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce6ef4da-d1f7-4f3a-9a25-52b29cf7aa39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def facedetect_preprocess_fn(img, target_size=(640, 640)):\n",
    "    \"\"\"Build the input dict for the face detector from one image.\n",
    "\n",
    "    Args:\n",
    "        img: HxWxC image as returned by `cv2.imread`.\n",
    "        target_size: (height, width) the image is resized to. A tuple\n",
    "            default avoids the shared-mutable-default pitfall; any\n",
    "            2-element sequence (e.g. `[640, 640]`) still works.\n",
    "\n",
    "    Returns:\n",
    "        dict with `image` (1xCxHxW float32, mean/std normalized),\n",
    "        `im_shape` (1x2 float32, resized height and width) and\n",
    "        `scale_factor` (1x2 float32, y and x resize ratios).\n",
    "    \"\"\"\n",
    "    resize_h, resize_w = target_size\n",
    "    src_h, src_w = img.shape[:2]\n",
    "    img_scale_x = resize_w / src_w\n",
    "    img_scale_y = resize_h / src_h\n",
    "    img = cv2.resize(\n",
    "        img, None, None, fx=img_scale_x, fy=img_scale_y,\n",
    "        interpolation=cv2.INTER_LINEAR)\n",
    "\n",
    "    scale = 1. / 255.\n",
    "    mean = np.array([[[0.485, 0.456, 0.406]]])\n",
    "    std = np.array([[[0.229, 0.224, 0.225]]])\n",
    "    # NOTE(review): ImageNet RGB statistics applied to a cv2 (BGR) image\n",
    "    # without channel conversion -- confirm this matches the exported model.\n",
    "    img = (img.astype('float32') * scale - mean) / std\n",
    "\n",
    "    img_info = {}\n",
    "    img_info['im_shape'] = np.array(\n",
    "        img.shape[:2], dtype=np.float32)[np.newaxis, :]\n",
    "    img_info['scale_factor'] = np.array(\n",
    "        [img_scale_y, img_scale_x], dtype=np.float32)[np.newaxis, :]\n",
    "\n",
    "    img = img.transpose((2, 0, 1)).copy()\n",
    "    img_info['image'] = img[np.newaxis, :, :, :].astype(np.float32)\n",
    "    return img_info\n",
    "\n",
    "def facedetect_postprocess_fn(outputs, thresh=0.8):\n",
    "    \"\"\"Filter raw detector output down to confident detections.\n",
    "\n",
    "    Rows of `outputs[0]` are kept when column 1 (the score) exceeds `thresh`\n",
    "    and column 0 is greater than -1 (presumably a class id where -1 marks an\n",
    "    empty slot -- confirm against the exported model).\n",
    "    \"\"\"\n",
    "    detections = outputs[0]\n",
    "    is_confident = detections[:, 1] > thresh\n",
    "    is_valid = detections[:, 0] > -1\n",
    "    return detections[is_confident & is_valid, :]\n",
    "\n",
    "# BlazeFace face detector: the exported inference model wired to the\n",
    "# pre/post-processing hooks defined in this cell.\n",
    "face_detector = Predictor(\n",
    "    model_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel',\n",
    "    params_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdiparams',\n",
    "    preprocess_fn=facedetect_preprocess_fn,\n",
    "    postprocess_fn=facedetect_postprocess_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d384f8-da7b-4dc8-8bd0-2e6ce9c63b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def facerecog_preprocess_fn(img):\n",
    "    \"\"\"Normalize a batch of BGR face crops for the FaceViT recognizer.\n",
    "\n",
    "    Args:\n",
    "        img: NxHxWxC batch of BGR crops (as produced by `crop_face`); a single\n",
    "            HxWxC crop is also accepted and treated as a batch of one.\n",
    "\n",
    "    Returns:\n",
    "        {'inputs': array} where array is NxCxHxW float32, channels reversed\n",
    "        to RGB and pixel values mapped from [0, 255] to [-1, 1].\n",
    "    \"\"\"\n",
    "    img = np.asarray(img)\n",
    "    if img.ndim == 3:\n",
    "        img = img[np.newaxis]\n",
    "    scale = 1.0 / 255.0\n",
    "    mean = 0.5\n",
    "    std = 0.5\n",
    "    img = (img.astype('float32') * scale - mean) / std\n",
    "    # Reverse the last (channel) axis: BGR -> RGB. `img[:, :, ::-1]` would hit\n",
    "    # axis 2 of this 4-D batch and mirror the width instead.\n",
    "    img = img[..., ::-1]\n",
    "    img = img.transpose((0, 3, 1, 2))\n",
    "\n",
    "    return {'inputs': np.ascontiguousarray(img)}\n",
    "\n",
    "def crop_face(img, box_list, size=112):\n",
    "    \"\"\"Cut a square, fixed-size crop around each detected face.\n",
    "\n",
    "    Each box is grown to a square centred on the detection whose side is the\n",
    "    longer box edge. Parts of that square outside the image are zero-padded,\n",
    "    so faces near the border keep their aspect ratio instead of yielding an\n",
    "    empty or wrapped-around slice.\n",
    "\n",
    "    Args:\n",
    "        img: HxWxC image array.\n",
    "        box_list: iterable of detector rows; columns 2: are read as\n",
    "            (xmin, ymin, xmax, ymax). The rows are not modified.\n",
    "        size: side length of each returned crop (default 112).\n",
    "\n",
    "    Returns:\n",
    "        N x size x size x C array with one crop per box (N == 0 when\n",
    "        `box_list` is empty).\n",
    "    \"\"\"\n",
    "    img_h, img_w = img.shape[:2]\n",
    "    batch = []\n",
    "    for box in box_list:\n",
    "        # Clamp negative coordinates on a copy; the caller's array is left intact.\n",
    "        coords = np.maximum(np.asarray(box[2:], dtype=np.float64), 0)\n",
    "        xmin, ymin, xmax, ymax = (int(v) for v in coords)\n",
    "        w = xmax - xmin + 1\n",
    "        h = ymax - ymin + 1\n",
    "        # At least 1 so a degenerate box never produces an empty crop.\n",
    "        radius = max(1, int(round(max(h, w) / 2.0)))\n",
    "        cx = int(round((xmax + xmin) / 2.0))\n",
    "        cy = int(round((ymax + ymin) / 2.0))\n",
    "        x0, x1 = cx - radius, cx + radius\n",
    "        y0, y1 = cy - radius, cy + radius\n",
    "\n",
    "        # A negative slice start would wrap to the far side of the image,\n",
    "        # so pad the image instead of slicing out of bounds.\n",
    "        pad = max(0, -x0, -y0, x1 - img_w, y1 - img_h)\n",
    "        src = img\n",
    "        if pad:\n",
    "            pad_width = ((pad, pad), (pad, pad)) + ((0, 0),) * (img.ndim - 2)\n",
    "            src = np.pad(img, pad_width, mode='constant')\n",
    "        face_img = src[y0 + pad:y1 + pad, x0 + pad:x1 + pad]\n",
    "        batch.append(cv2.resize(face_img, (size, size)))\n",
    "    if not batch:\n",
    "        return np.zeros((0, size, size) + img.shape[2:], dtype=img.dtype)\n",
    "    return np.stack(batch)\n",
    "\n",
    "# FaceViT-tiny face recognizer. No postprocess_fn is given, so predict()\n",
    "# presumably returns the raw model outputs -- confirm in plsc's Predictor.\n",
    "face_recog = Predictor(\n",
    "    model_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel',\n",
    "    params_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdiparams',\n",
    "    preprocess_fn=facerecog_preprocess_fn,\n",
    "    postprocess_fn=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeea0e11-c565-44ee-a10c-395c8c59097c",
   "metadata": {},
   "outputs": [],
   "source": [
    "feats_list = []\n",
    "fileid_list = []\n",
    "boxes_list = []\n",
    "\n",
    "# Sorted so file ids, and therefore the cluster listing, are reproducible:\n",
    "# glob's order depends on the filesystem.\n",
    "filenames = sorted(glob.glob('images/*.png'))\n",
    "for idx, filename in enumerate(filenames):\n",
    "    img = cv2.imread(filename)\n",
    "    if img is None:  # cv2.imread returns None for unreadable files\n",
    "        print(f'skipping unreadable image: {filename}')\n",
    "        continue\n",
    "    boxes = face_detector.predict(img)\n",
    "    if len(boxes) == 0:  # no face detected: nothing to crop or embed\n",
    "        continue\n",
    "\n",
    "    faces = crop_face(img, boxes)\n",
    "    feats = face_recog.predict(faces)\n",
    "\n",
    "    # Assumes the recognizer's first output holds one embedding row per face.\n",
    "    feats_list.append(feats[0])\n",
    "    fileid_list.append(np.full(faces.shape[0], idx, dtype=np.int32))\n",
    "    boxes_list.append(boxes)\n",
    "\n",
    "assert feats_list, 'no faces were detected in images/*.png'\n",
    "face_feat = np.concatenate(feats_list, axis=0)\n",
    "face_file = np.concatenate(fileid_list, axis=0)\n",
    "face_boxes = np.concatenate(boxes_list, axis=0)\n",
    "\n",
    "# L2-normalize the embeddings; cosine distance itself does not depend on length.\n",
    "X = face_feat / np.linalg.norm(face_feat, axis=-1, keepdims=True)\n",
    "\n",
    "# DBSCAN's default metric is Euclidean; cosine distance is used for the embeddings.\n",
    "db = DBSCAN(eps=0.5, min_samples=2, metric='cosine').fit(X)\n",
    "core_samples_mask = np.zeros_like(db.labels_, dtype=bool)\n",
    "core_samples_mask[db.core_sample_indices_] = True\n",
    "labels = db.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c8f5a7-73a6-4341-b794-f11f35960ec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_image = True\n",
    "copy_image = False\n",
    "skip_noise = False  # DBSCAN labels outliers -1; set True to hide them\n",
    "\n",
    "# Sorted so clusters are listed in a stable order (noise, if any, first).\n",
    "clusters = sorted(set(labels))\n",
    "output_root = 'clusters'\n",
    "for clusters_id in clusters:\n",
    "    if skip_noise and clusters_id == -1:\n",
    "        continue\n",
    "    face_idx = np.where(labels == clusters_id)[0]\n",
    "\n",
    "    sel_fileids = face_file[face_idx]\n",
    "    sel_boxes = face_boxes[face_idx]\n",
    "    title = 'noise (unclustered faces)' if clusters_id == -1 else f'face id {clusters_id}'\n",
    "    print()\n",
    "    print('=' * 20, title, '=' * 20)\n",
    "\n",
    "    output_dir = os.path.join(output_root, str(clusters_id))\n",
    "    if copy_image:\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    for idx in range(sel_fileids.shape[0]):\n",
    "        filename = filenames[sel_fileids[idx]]\n",
    "        img = cv2.imread(filename)\n",
    "        # cv2 loads BGR while PIL expects RGB, hence the channel flip.\n",
    "        img_drawed = draw(img[:, :, ::-1], [sel_boxes[idx]])\n",
    "\n",
    "        if show_image:\n",
    "            display_img_array(img_drawed)\n",
    "\n",
    "        if copy_image:\n",
    "            shutil.copyfile(filename, os.path.join(output_dir, os.path.basename(filename)))\n",
    "            # The first face of each cluster is saved as its thumbnail.\n",
    "            if idx == 0:\n",
    "                cropped = crop_face(img, [sel_boxes[idx]])[0]\n",
    "                cv2.imwrite(os.path.join(output_dir, 'thumbnail.png'), cropped)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
